Amazon SageMaker Debuggerのビルトインルールについて調べてみた
機械学習モデルを学習させる際には過学習や勾配消失等といった多種多様な問題の発生は避けられません。 Amazon SageMaker Debugger を使うことで、TensorFlow や MXNet、PyTorch、XGBoost を使ったモデル学習時の異常や問題を検出することができます。Amazon SageMaker Debuggerはルールにマッチしたかどうかによって評価し、学習時の異常や問題を検出します。
各ルールはトライアル(学習)の各種テンソルデータを参照し、ルールにマッチした場合にTrueを返します。ルールがTrueを返した場合には、ルールの評価処理によって例外が投げられます。CloudWatch EventsのSageMaker Training Job State Change
イベントを用いることで評価処理の結果をトリガーとして、通知などニアリアルタイムでの処理を実行することも可能です。
今回は Amazon SageMaker Debugger で事前に用意されているビルトインルール(Built-in Rules)にどういったものがあるか見てみます。
深層学習フレームワーク用
対応フレームワーク: TensorFlow/Apache MXNet/PyTorch
DeadRelu
ReLUが死んでる(dead)かどうかを検出します。 レイヤーの中で死んでいるReLUの割合が閾値以上の場合にTrueを返します。
パラメータ
tensor_regex, threshold_inactivity, threshold_layer
ExplodingTensor
発散を検出します。対象のテンソルにNaNやinfを検出した場合にTrueを返します。
パラメータ
collection_names, tensor_regex, only_nan
PoorWeightInitialization
学習の最初の数ステップの間の挙動から各重みの初期化が不十分かどうかを検出します。
- レイヤーごとの重みの分散の最小値と最大値の比が閾値を越えるとTrueを返します
- レイヤーごとの勾配分布の5パーセンタイルと95パーセンタイルの差の最小値が閾値以下ならTrueを返します
- ロスが指定したステップの間減らなければTrueを返します
パラメータ
activation_inputs_regex, threshold, distribution_range, patience, steps
SaturatedActivation
アークタンジェントもしくはシグモイドを用いたアクティベーションレイヤーが飽和しているかどうかを検出します。 飽和しているアクティベーションノードの割合が閾値を超えている場合にTrueを返します
パラメータ
collection_names, tensor_regex, threshold_tanh, threshold_sigmoid, threshold_inactivity, threshold_layer
VanishingGradient
勾配の消失を検出します。勾配の絶対値の平均が閾値より低くなった場合にTrueを返します。
パラメータ
threshold
WeightUpdateRatio
重み更新時の変化率の異常を検出します。重み更新時の変化率が閾値として設定した最大値もしくは最小値を超えた場合にTrueを返します。
パラメータ
num_steps, large_threshold, small_threshold, epsilon
深層学習フレームワークとXGBoost用
対応フレームワーク: TensorFlow/Apache MXNet/PyTorch/XGBoost
AllZero
テンソルの値が0になっている割合が高過ぎないかを検出します。 テンソルの値の0の割合が閾値を超えた場合にTrueを返します。
パラメータ
collection_names, tensor_regex, threshold
ClassImbalance
分類モデルにおいて、クラスの偏りを検出します。
- クラスごとのサンプル数が最大のものと最小のものの比が閾値を超えた場合にTrueを返します。
- 各クラスで誤って推論したデータの割合が閾値を超えた場合にTrueを返します。
- 例えば、Aというクラスでの誤推論の割合が閾値以下で、Bというクラスにおける誤推論の割合が閾値を超えた場合はTrueが返されます。
パラメータ
threshold_imbalance, threshold_misprediction, samples, argmax, labels_regex, predictions_regex
Confusion
分類モデルの混同行列の良さを評価し、問題を検出します。
- 混同行列の
対角要素
/対角要素の合計
が閾値より小さい場合にTrueを返します。 - 混同行列の
非対角要素
/列要素の合計
が閾値より大きい場合にTrueを返します。
パラメータ
category_no, labels, predictions, labels_collection, predictions_collection, min_diag, max_off_diag
LossNotDecreasing
ロスが減少しなくなったことを検出します。 指定したステップ間でのロスの減少割合が閾値より小さい場合にTrueを返します。
パラメータ
collection_names, tensor_regex, use_losses_collection, num_steps, diff_percent, mode
Overfit
学習データ(training)と検証データ(validation)のロスを比較することで、モデルの過剰適合(overfitting)を検出します。
学習データのロスの平均
に対する、検証データのロスの平均
と学習データのロスの平均
との差の割合が閾値を一定ステップ間超え続けた場合にTrueを返します。
パラメータ
tensor_regex, start_step, patience, ratio_threshold,
Overtraining
学習データ(training)や検証データ(validation)のロスの減少具合からモデルの過学習(overtraining)を検出します。
パラメータ
patience_train, patience_validation, delta
SimilarAcrossRuns
対象のトライアルと他のトライアルが似ているかどうかを検出します。
パラメータ
other_trial, collection_names, tensor_regex
TensorVariance
特定のテンソルの分散が高すぎたり、低すぎないかを検出します。
パラメータ
collection_names, tensor_regex, max_threshold, min_threshold
UnchangedTensor
テンソルがステップごとに変化がないことを検出します。テンソル同士の比較にはnumpy.allcloseが使われます。
パラメータ
collection_names, tensor_regex, num_steps, rtol, atol, equal_nan
深層学習アプリケーション用
CheckInputImages
サンプリングした入力画像の平均が0から閾値以上に離れているかどうかで、入力画像が正しく正規化されているかを検証します。
パラメータ
threshold_mean, threshold_samples, regex, channel
NLPSequenceRatio
自然言語処理において、特定のトークン(EOSやunknownなど)が入力トークンの中で占める割合が多すぎないかを検証します。
パラメータ
tensor_regex, token_values, token_thresholds_percent
XGBoost用
TreeDepth
学習によって作成された木の深さを測定します。
※ 具体的な記述がドキュメントにないので詳細は不明ですが、恐らく学習中の木の深さが閾値より浅いかどうかを検出するルールだと思います。
パラメータ
depth
さいごに
Amazon SageMaker Debugger の各種ビルトインルールの概要について紹介しました。用途に応じて適切なルールとパラメータを選択することで、モデルを学習する際に起こりうる様々な異常や問題を検出し、モデルの開発を効率化させることができそうです。